
#### HOUSEKEEPING ####

rm(list = ls())

## PACKAGES
# install_github("cran/LowRankQP")
package_list <- c("MASS", "ggplot2", "gtable", "grid", "reshape2", "doParallel", "data.table", "tidyverse", "LowRankQP")
# library() stops immediately when a package is missing (require() only
# returns FALSE and the script fails later with an obscure error);
# invisible() keeps Rscript from auto-printing the returned list.
invisible(lapply(package_list, library, character.only = TRUE))

## PARAMS
n_job   <- 400
date    <- "final"
n_s     <- 200 # make sure n_s divides n_job
# Fail fast: if n_s does not divide n_job, array tasks are silently
# mis-assigned to resamples.
stopifnot(n_job %% n_s == 0)
j_per_s <- n_job / n_s

## DIRECTORIES
source("00_paths.R")
out_dir <- paste0(res_dir, "pensynth/", date, "/")
# Read lambda from the same out_dir used for every other input of this run.
lambda  <- fread(file = paste0(out_dir, "inputs/lambda.csv"))$lambda
# Directory and file names below assume a single penalty value.
stopifnot(length(lambda) == 1)
l_print <- format(lambda, scientific = TRUE)

# Create every output directory unconditionally; dir.create() with
# showWarnings = FALSE is idempotent. Guarding on "lambda dir missing" races
# when many array tasks start at once: one task may have created <lambda>/
# but not yet the SE/<resample>/ directories another task is about to write
# into. recursive = TRUE also creates fixed/ (and parents) if absent.
fixed_dir   <- paste0(out_dir, "fixed/")
dirs_needed <- c(
  paste0(fixed_dir, l_print, "/"),
  paste0(fixed_dir, l_print, "_", "SE", "/", sprintf("%03d", seq_len(n_s)), "/")
)
for (dir_path in dirs_needed) {
  dir.create(dir_path, showWarnings = FALSE, recursive = TRUE)
}

#### FUNCTIONS ####

## PENSYNTH FUNCTIONS
fun_list <- c("wsoll1", "regsynth", "regsynthpath", "pensynth_parallel", "TZero", "estimator_matching", "get_stats")
fun_list <- paste0(ps_dir, "/functions/", fun_list, ".R")
# source() is called only for its side effect of defining the helpers; a
# plain loop avoids Rscript auto-printing the list lapply() would return
# (which dumps every sourced file's last value into the job log).
for (fun_file in fun_list) {
  source(fun_file)
}

#### ~ GROUP SETUP FUNCTION ####
# Build the matrices used by the penalized synthetic-control solver for one
# group of units.
#
# Args:
#   df:        data.table (the `..X_vars` selection below requires it) with
#              columns `group`, `customer_key`, `y_var`, `treat_var`, X_vars.
#   group:     optional group value. When supplied, only rows with
#              df$group == group are kept; when NULL (default) all rows of
#              `df` are used, which is what callers that pre-filter rely on.
#   X_vars:    character vector of covariate column names.
#   y_var:     name of the outcome column.
#   treat_var: name of the treatment indicator column (1 treated, 0 control).
#
# Returns a named list:
#   d, y            treatment indicator and outcome vectors
#   X               n x p covariate matrix, each column divided by its SD
#                   among treated rows
#   X_unscaled      the covariates before rescaling (data.frame)
#   X0, X1          p x n0 / p x n1 scaled covariates of controls / treated
#   Y0, Y1          outcomes of controls / treated
#   V               p x p identity matrix
#   X0_unique       p x m matrix of distinct control covariate profiles
#   Y0_average      mean control outcome for each column of X0_unique
#   X0_ids, X1_ids  customer_key of controls / treated
format_group <- function(df, group = NULL, X_vars, y_var, treat_var) {

  if (!is.null(group)) {
    # Compute the row index outside `[`: inside a data.table `[...]` the
    # symbol `group` resolves to the *column*, which turned a filter written
    # as df[df$group == group, ] into a no-op. A bare symbol in `i` is
    # evaluated in the calling scope.
    row_idx <- which(df[["group"]] == group)
    df <- df[row_idx, ]
  }

  d          <- df[[treat_var]]
  y          <- df[[y_var]]
  # as.data.frame() (unlike data.frame()) does not mangle non-syntactic
  # column names, so later lookups by X_vars keep working.
  X          <- as.data.frame(df[, ..X_vars])
  X_unscaled <- X
  # Rescale each covariate by its standard deviation among treated rows.
  # NOTE(review): gives NA/Inf when a group has fewer than 2 treated units or
  # a covariate is constant among them -- confirm such groups cannot occur.
  X[X_vars]  <- lapply(X_vars, function(v) X[[v]] / sd(X[[v]][d == 1]))

  X  <- as.matrix(X)
  # drop = FALSE keeps a one-row selection a matrix, so X0 / X1 stay p x n
  # even for a group with a single control or a single treated unit.
  X0 <- t(X[d == 0, , drop = FALSE])
  X1 <- t(X[d == 1, , drop = FALSE])
  Y0 <- y[d == 0]
  Y1 <- y[d == 1]
  V  <- diag(ncol(X))

  # Collapse duplicated control covariate profiles, averaging their outcomes.
  X0_unique  <- as.data.table(cbind(Y0, t(X0)))
  X0_unique  <- X0_unique[, list(Y0_average = mean(Y0)), by = X_vars]
  Y0_average <- as.vector(X0_unique[, Y0_average])
  X0_unique  <- t(as.matrix(X0_unique[, ..X_vars]))

  X0_ids <- df[["customer_key"]][d == 0]
  X1_ids <- df[["customer_key"]][d == 1]

  list(d = d,
       y = y,
       X = X,
       X_unscaled = X_unscaled,
       X0 = X0,
       X1 = X1,
       Y0 = Y0,
       Y1 = Y1,
       V = V,
       X0_unique = X0_unique,
       Y0_average = Y0_average,
       X0_ids = X0_ids,
       X1_ids = X1_ids)
}

#### SET UP ####

## JOB ARRAY

# SLURM array index of this task; default to task 200 when run outside SLURM.
task_id <- as.integer(Sys.getenv("SLURM_ARRAY_TASK_ID"))
if (is.na(task_id)) {
  task_id <- 200
}
# Resample s (1..n_s) is handled by array tasks (s-1)*j_per_s+1 .. s*j_per_s.
cur_s <- ceiling(task_id / j_per_s)

## LOAD DATA
df     <- fread(file = paste0(out_dir, "inputs/resamp/data_", cur_s, ".csv"))
X_vars <- fread(file = paste0(out_dir, "inputs/X_vars.csv"))$X_vars

groups <- sort(unique(df$group))
setup  <- list()

# One (Group, Index) row per treated unit; collected per group and bound once
# instead of growing a data.frame inside the loop.
grid_parts <- vector("list", length(groups))

for (i_g in seq_along(groups)) {
  g <- groups[i_g]
  setup[[g]] <- format_group(df = df[df$group == g, ],
                             X_vars = X_vars,
                             y_var = "mean_gb_post",
                             treat_var = "treat")
  n_treated <- length(setup[[g]]$Y1)
  # seq_len()/rep() so a group with no treated units contributes no rows
  # (1:0 would have produced bogus Index values 1 and 0).
  grid_parts[[i_g]] <- data.frame(
    Group = rep(g, n_treated),
    Index = seq_len(n_treated)
  )
}

# The empty template keeps grid_df a 0-row data.frame when there are no groups.
grid_df <- do.call(rbind, c(list(data.frame(Group = numeric(), Index = numeric())),
                            grid_parts))

# Split this resample's grid evenly across its j_per_s array tasks.
n_grid      <- nrow(grid_df)
n_per_job   <- ceiling(n_grid / j_per_s)
cur_s_st    <- (cur_s - 1) * j_per_s + 1  # first array task of this resample
cur_s_en    <- cur_s * j_per_s            # last array task of this resample (unused)
job_start   <- (task_id - cur_s_st) * n_per_job + 1
job_end     <- min((task_id - cur_s_st + 1) * n_per_job, n_grid)
# Tasks whose chunk starts past the end of the grid have nothing to do.
job_id_list <- if (job_start > n_grid) NULL else job_start:job_end

#### ESTIMATION ####

## LOOP OVER IDS
# For each (group, treated unit) pair assigned to this task: solve for the
# penalized synthetic-control weights with wsoll1(), post-process them with
# TZero(), report the elapsed time, and write one CSV per treated unit.
for (id in job_id_list) {

  grp_name  <- grid_df$Group[id]
  unit_idx  <- grid_df$Index[id]
  grp_setup <- setup[[grp_name]]

  out_file <- paste0(out_dir, "fixed/", l_print, "_", "SE", "/",
                     sprintf("%03d", cur_s), "/", grp_name, "_", unit_idx, ".csv")
  print(out_file)

  t_start <- Sys.time()
  weights <- TZero(
    wsoll1(X0 = grp_setup$X0_unique,
           X1 = grp_setup$X1[, unit_idx],
           V = grp_setup$V,
           pen = lambda)
  )
  print(Sys.time() - t_start)

  fwrite(x = data.frame(weights), file = out_file)
}
